{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image Segmentation with ResNet U-Net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
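    "Based on code in Divam Gupta's [image-segmentation-keras](https://github.com/divamgupta/image-segmentation-keras) repository. This notebook trains a ResNet-50 U-Net on the prepped images in `image-seg-data/`, reloads a saved checkpoint, and predicts a segmentation map for a single test image."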
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Select the GPU before `import keras_segmentation` (TensorFlow backend) so the setting is in place when CUDA initializes.\n",
    "# PCI_BUS_ID makes device indices follow PCI bus order (the order nvidia-smi lists them in).\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "# Use GPU 0; set this to an empty string instead to hide all GPUs and run on the CPU.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras_segmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Build the ResNet-50 U-Net; 416 x 608 input (both multiples of 32).\n",
    "# NOTE(review): n_classes must cover every label value in the annotation images - TODO confirm 51 against the dataset.\n",
    "model = keras_segmentation.models.unet.resnet50_unet(\n",
    "    n_classes=51,\n",
    "    input_height=416,\n",
    "    input_width=608,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 22/367 [00:00<00:01, 216.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Verifying train dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 367/367 [00:01<00:00, 213.49it/s]\n",
      " 21%|██        | 21/101 [00:00<00:00, 207.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset verified! \n",
      "Verifying val dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 101/101 [00:00<00:00, 204.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset verified! \n",
      "Starting Epoch  0\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 183s - loss: 0.6805 - acc: 0.8234 - val_loss: 0.4753 - val_acc: 0.8559\n",
      "saved  model_output/image-seg-ResNet.model.0\n",
      "Finished Epoch 0\n",
      "Starting Epoch  1\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.2582 - acc: 0.9208 - val_loss: 0.2969 - val_acc: 0.9081\n",
      "saved  model_output/image-seg-ResNet.model.1\n",
      "Finished Epoch 1\n",
      "Starting Epoch  2\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.1808 - acc: 0.9416 - val_loss: 0.2757 - val_acc: 0.9178\n",
      "saved  model_output/image-seg-ResNet.model.2\n",
      "Finished Epoch 2\n",
      "Starting Epoch  3\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.1463 - acc: 0.9510 - val_loss: 0.2725 - val_acc: 0.9195\n",
      "saved  model_output/image-seg-ResNet.model.3\n",
      "Finished Epoch 3\n",
      "Starting Epoch  4\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.1222 - acc: 0.9578 - val_loss: 0.3011 - val_acc: 0.9150\n",
      "saved  model_output/image-seg-ResNet.model.4\n",
      "Finished Epoch 4\n",
      "Starting Epoch  5\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.1058 - acc: 0.9626 - val_loss: 0.3182 - val_acc: 0.9101\n",
      "saved  model_output/image-seg-ResNet.model.5\n",
      "Finished Epoch 5\n",
      "Starting Epoch  6\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0955 - acc: 0.9656 - val_loss: 0.3255 - val_acc: 0.9176\n",
      "saved  model_output/image-seg-ResNet.model.6\n",
      "Finished Epoch 6\n",
      "Starting Epoch  7\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 139s - loss: 0.0879 - acc: 0.9680 - val_loss: 0.3047 - val_acc: 0.9213\n",
      "saved  model_output/image-seg-ResNet.model.7\n",
      "Finished Epoch 7\n",
      "Starting Epoch  8\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0807 - acc: 0.9702 - val_loss: 0.3020 - val_acc: 0.9231\n",
      "saved  model_output/image-seg-ResNet.model.8\n",
      "Finished Epoch 8\n",
      "Starting Epoch  9\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0764 - acc: 0.9716 - val_loss: 0.3123 - val_acc: 0.9253\n",
      "saved  model_output/image-seg-ResNet.model.9\n",
      "Finished Epoch 9\n",
      "Starting Epoch  10\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0716 - acc: 0.9732 - val_loss: 0.3855 - val_acc: 0.9100\n",
      "saved  model_output/image-seg-ResNet.model.10\n",
      "Finished Epoch 10\n",
      "Starting Epoch  11\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 141s - loss: 0.0672 - acc: 0.9747 - val_loss: 0.3102 - val_acc: 0.9257\n",
      "saved  model_output/image-seg-ResNet.model.11\n",
      "Finished Epoch 11\n",
      "Starting Epoch  12\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0641 - acc: 0.9757 - val_loss: 0.3484 - val_acc: 0.9228\n",
      "saved  model_output/image-seg-ResNet.model.12\n",
      "Finished Epoch 12\n",
      "Starting Epoch  13\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0615 - acc: 0.9767 - val_loss: 0.3679 - val_acc: 0.9191\n",
      "saved  model_output/image-seg-ResNet.model.13\n",
      "Finished Epoch 13\n",
      "Starting Epoch  14\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0580 - acc: 0.9779 - val_loss: 0.3641 - val_acc: 0.9208\n",
      "saved  model_output/image-seg-ResNet.model.14\n",
      "Finished Epoch 14\n",
      "Starting Epoch  15\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0565 - acc: 0.9784 - val_loss: 0.3556 - val_acc: 0.9209\n",
      "saved  model_output/image-seg-ResNet.model.15\n",
      "Finished Epoch 15\n",
      "Starting Epoch  16\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 139s - loss: 0.0549 - acc: 0.9790 - val_loss: 0.3627 - val_acc: 0.9232\n",
      "saved  model_output/image-seg-ResNet.model.16\n",
      "Finished Epoch 16\n",
      "Starting Epoch  17\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0524 - acc: 0.9799 - val_loss: 0.3631 - val_acc: 0.9221\n",
      "saved  model_output/image-seg-ResNet.model.17\n",
      "Finished Epoch 17\n",
      "Starting Epoch  18\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 139s - loss: 0.0505 - acc: 0.9806 - val_loss: 0.4280 - val_acc: 0.9138\n",
      "saved  model_output/image-seg-ResNet.model.18\n",
      "Finished Epoch 18\n",
      "Starting Epoch  19\n",
      "Epoch 1/1\n",
      "512/512 [==============================] - 140s - loss: 0.0493 - acc: 0.9810 - val_loss: 0.3706 - val_acc: 0.9231\n",
      "saved  model_output/image-seg-ResNet.model.19\n",
      "Finished Epoch 19\n"
     ]
    }
   ],
   "source": [
    "# Train for 20 epochs; a checkpoint `<checkpoints_path>.model.<epoch>` is saved after every epoch.\n",
    "# NOTE(review): the *test* split doubles as the validation set, so selecting a checkpoint by val_loss\n",
    "# makes any later evaluation on that split optimistic.\n",
    "model.train(\n",
    "    train_images=\"image-seg-data/images_prepped_train/\",\n",
    "    train_annotations=\"image-seg-data/annotations_prepped_train/\",\n",
    "    validate=True,\n",
    "    val_images=\"image-seg-data/images_prepped_test/\",\n",
    "    val_annotations=\"image-seg-data/annotations_prepped_test/\",\n",
    "    checkpoints_path=\"model_output/image-seg-ResNet\",\n",
    "    epochs=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Reload the epoch-3 checkpoint (epochs are 0-indexed, so this is after 4 epochs): it has the\n",
    "# lowest val_loss (0.2725) in the training log above; later epochs overfit.\n",
    "# train() saved checkpoints as `<checkpoints_path>.model.<epoch>` (see the 'saved' lines in the log),\n",
    "# not inside a model_output/image-seg-ResNet/ directory.\n",
    "model.load_weights('model_output/image-seg-ResNet.model.3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Segment a single test-split image and write the prediction to disk.\n",
    "# The `4epochs` suffix matches checkpoint index 3 (0-indexed) loaded above.\n",
    "out = model.predict_segmentation(\n",
    "    out_fname=\"output-ResNet-4epochs.png\",\n",
    "    inp=\"image-seg-data/images_prepped_test/0016E5_07965.png\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
